# import packages and modules
import numpy as np
import pandas as pd
import os
import sys
import pickle
from joblib import dump, load
import yaml
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LinearRegression as reg
import pdb
from pprint import pprint
from statistics import mean
import random
random.seed(50)
import math
import argparse
sys.path.insert(1,'/REDACTED/fairness/code/rf/scripts/rf')
from gpd_test import *
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import cm
import matplotlib.font_manager as fm
import matplotlib
import sklearn.metrics as metrics
import seaborn as sns
import pdb
sys.path.insert(1,'/REDACTED/fairness/code/utilities/')
from trajectoryPlotters import *

# Default directory containing the modeled-output pickles, and default
# directory where figures are written.
# NOTE(review): the name looks like a find/replace artifact (compare the
# 'Underreportaxpayer_idg' label text elsewhere in this file); kept as-is
# because other code refers to it by this name.
defaultaxpayer_id = '/REDACTED/data/modeled_refactor_temp/'
defaultout = '/REDACTED/'

# set plot defaults: apply the project style sheet, then register the bundled
# cmunrm.ttf font under the name 'latex' (inserted at the front of the font
# list) and make it the default font family.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(fname='/REDACTED/fairness/code/utilities/cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name
plt.rcParams['mathtext.default'] = 'regular'

# Stanford palette as (name, hex) pairs; keeping each name next to its colour
# makes it impossible for the two to drift out of alignment.
_palette_pairs = [
    ('background blue', '#f0f4f5'),
    ('gray blue', '#d0d8da'),
    ('dark gray-blue', '#aabeC6'),
    ('accent blue', '#009abb'),
    ('link blue', '#007c92'),
    ('dark blue', '#09425a'),
    ('light green', '#c7d1c6'),
    ('bright green', '#80982a'),
    ('dark green', '#556222'),
    ('orange', '#b96d12'),
    ('purple', '#53284f'),
    ('maroon', '#5e3032'),
    ('cardinal red', '#8c1515'),
    ('line gray', '#dddddd'),
    ('light gray', '#cccccc'),
    ('gray', '#999999'),
    ('med gray', '#666666'),
    ('text gray', '#333333'),
    ('black', '#000000'),
]
stanford_palette = [hex_code for _, hex_code in _palette_pairs]
color_codes = [color_name for color_name, _ in _palette_pairs]

# name -> RGBA tuple; registering these in matplotlib's named-colour mapping
# lets plotting calls pass e.g. color='link blue'.
cdict = {color_name: mcolors.to_rgba(hex_code) for color_name, hex_code in _palette_pairs}
mcolors.get_named_colors_mapping().update(cdict)

###############################################
#### GET PLOT DICTIONARIES
###############################################


def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    NOTE(review): unpickling can execute arbitrary code; these files are
    assumed to be trusted, locally produced outputs.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


## Status Quo
status_quo_point = _load_pickle('/REDACTED/fairness/code/rf/data/status_quo.pickle')

### Overclaiming Oracle
unres_eitc_oracle_dep_database_ref_cred = _load_pickle(
    defaultaxpayer_id + 'eitc_unres_oracle_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')
res_eitc_oracle_dep_database_ref_cred = _load_pickle(
    defaultaxpayer_id + 'eitc_res_oracle_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

### Refundable Credit - REGRESSOR
unres_eitc_reg_dep_database_ref_cred = _load_pickle(
    defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')
res_eitc_reg_dep_database_ref_cred = _load_pickle(
    defaultaxpayer_id + 'eitc_res_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

### Oracle
unres_eitc_oracle_dep_database = _load_pickle(
    defaultaxpayer_id + 'eitc_unres_oracle_plus_dep_database_output_new_bifsg.pickle')
res_eitc_oracle_dep_database = _load_pickle(
    defaultaxpayer_id + 'eitc_res_oracle_plus_dep_database_output_new_bifsg.pickle')

### Regressor
unres_eitc_reg_dep_database = _load_pickle(
    defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_new_bifsg.pickle')
res_eitc_reg_dep_database = _load_pickle(
    defaultaxpayer_id + 'eitc_res_reg_plus_dep_database_output_new_bifsg.pickle')

# Name -> loaded plot dictionary.  Keeping each name next to its object
# replaces two hand-synchronised parallel lists that could silently drift
# apart and mislabel results.
_plot_inputs = {
    'res_eitc_oracle_dep_database_ref_cred': res_eitc_oracle_dep_database_ref_cred,
    'unres_eitc_oracle_dep_database_ref_cred': unres_eitc_oracle_dep_database_ref_cred,
    'res_eitc_reg_dep_database_ref_cred': res_eitc_reg_dep_database_ref_cred,
    'unres_eitc_reg_dep_database_ref_cred': unres_eitc_reg_dep_database_ref_cred,
    'res_eitc_oracle_dep_database': res_eitc_oracle_dep_database,
    'unres_eitc_oracle_dep_database': unres_eitc_oracle_dep_database,
    'res_eitc_reg_dep_database': res_eitc_reg_dep_database,
    'unres_eitc_reg_dep_database': unres_eitc_reg_dep_database,
}

# Parallel-list views kept for any code that still expects them (same order).
string_list = list(_plot_inputs)
file_list = list(_plot_inputs.values())

# get quantities of interest from plot dictionaries
data_dict = {}
for name, plot_dict in _plot_inputs.items():
    print(name)  # progress indicator
    data_dict[name] = getAllQuantities(plot_dict, bootstrap=True)

def _budget_index(budget):
    """Return the trajectory index for an audit-rate fraction ``budget``.

    The trajectory plots label point ``j`` as an audit rate of ``(j + 1) / 100``
    percent, i.e. a budget fraction of ``(j + 1) / 10000``, so the matching
    point is ``round(budget * 10000) - 1``.  Rounding rather than flooring
    matters: the float product can land just below the intended integer
    (compare ``0.29 * 100 == 28.999999999999996`` in Python), and flooring
    that selects the previous point.

    >>> _budget_index(0.0123)
    122
    """
    return int(round(budget * 10000)) - 1


## Local makeTrajectoryPlot, defined here so the status-quo label can be customised.
def makeTrajectoryPlot(modelnames,
                       quantities,
                       status_quo_point,
                       colors,
                       linestyles,
                       trajname,
                       plot_sq=True,
                       dict_x_varb='bootstrap_revenue',
                       dict_y_varb='full_chen_fair',
                       dict_se_varb='full_chen_fair',
                       sq_y_varb='eitc_chen_fair',
                       x_lab='Detected Underreporting ($ Millions)',
                       y_lab='Disparity (percentage points)',
                       activity_code=None,
                       out=defaultout,
                       eitc=False,
                       offset=False,
                       annotatePct=True,
                       annotateModelName=False,
                       annotationcoords=((0, 0),),
                       ref_pt_offset=None,
                       sq_coords_eitc=(500, -0.001),
                       sq_coords_pop=(1000, -0.0005),
                       plot_error=True,
                       title_lab=''):
    """Plot revenue-vs-disparity trajectories for several models and save a PNG.

    For model ``i`` one line is drawn with
    ``x = quantities[i][dict_x_varb + '_mean']`` and
    ``y = quantities[i][dict_y_varb + '_mean']``, optionally with a band of
    ``+/- 1.96 * quantities[i][dict_se_varb + '_std']`` (a normal-approximation
    95% interval).  Y tick labels are displayed multiplied by 100, so decimal
    disparities read as percentage points.

    Status-quo values are read from ``status_quo_point`` using a key prefix:
    ``str(activity_code)`` if ``activity_code`` is given, else ``'eitc'`` if
    ``eitc`` is true, else ``'overall'``.  Keys used are ``<prefix>_rev``
    (divided by 1e6 for the x axis), ``<prefix>_<sq_y_varb>`` and
    ``<prefix>_budget`` (an audit-rate fraction).

    Args:
        modelnames: legend / annotation text, one per model.
        quantities: per-model dicts holding ``<varb>_mean`` / ``<varb>_std``.
        status_quo_point: dict of status-quo values (see above).
        colors: per-model matplotlib colour.
        linestyles: per-model matplotlib line style.
        trajname: file stem; the figure is saved to ``<out>/<trajname>.png``.
        plot_sq: draw the status-quo marker, dotted crosshair and label.
        dict_x_varb, dict_y_varb, dict_se_varb: key stems into ``quantities``.
        sq_y_varb: key suffix for the status-quo y value.
        x_lab, y_lab, title_lab: axis labels and title.
        activity_code: optional key prefix that takes precedence over ``eitc``.
        out: output directory.
        eitc: use ``'eitc'`` status-quo keys (and EITC label layout).
        offset: shift the 1%/2%/3% labels of odd-indexed models down by 0.001.
        annotatePct: mark each model's point at the status-quo audit rate.
        annotateModelName: write ``modelnames[i]`` at ``annotationcoords[i]``.
        annotationcoords: per-model (x, y) positions for model names.
        ref_pt_offset: per-model flags; True places the audit-rate label below
            the marker instead of above.  ``None`` means False for every model.
        sq_coords_eitc, sq_coords_pop: (dx, dy) status-quo label offsets for
            the EITC / non-EITC layouts.
        plot_error: draw the +/- 1.96 * std band.

    Returns:
        The matplotlib Figure.  It is not closed here; callers should close it.
    """
    # `ticker` is not among this file's explicit imports (it could only arrive
    # through a star import), so import it locally for the y-axis formatter.
    from matplotlib import ticker

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    fig.suptitle('', fontsize=24)
    ax.set_ylabel(y_lab, fontsize=20)
    ax.set_xlabel(x_lab, fontsize=20)
    ax.set_title(str(title_lab), fontsize=20)
    ax.tick_params(labelsize=16)
    ax.axhline(y=0, color='gray')

    xs = [q[dict_x_varb + '_mean'] for q in quantities]
    ys = [q[dict_y_varb + '_mean'] for q in quantities]
    errs = [q[dict_se_varb + '_std'] for q in quantities]
    if ref_pt_offset is None:
        ref_pt_offset = [False] * len(modelnames)

    # Which status_quo_point keys to read, and which label layout to use.
    eitc_layout = bool(eitc) and activity_code is None
    if activity_code is not None:
        sq_prefix = str(activity_code)
    elif eitc:
        sq_prefix = 'eitc'
    else:
        sq_prefix = 'overall'

    if plot_sq:
        sq_x = status_quo_point[sq_prefix + '_rev'] / 1000000
        sq_y = status_quo_point[sq_prefix + '_' + sq_y_varb]
        ax.axhline(y=sq_y, color='red', linestyle='dotted')
        ax.axvline(x=sq_x, color='red', linestyle='dotted')
        ax.plot(sq_x, sq_y, 'kx', markersize=10)
        sq_pct = str(round(status_quo_point[sq_prefix + '_budget'] * 100, 2))
        if eitc_layout:
            ax.annotate('Status quo\n    ' + sq_pct + '%', xy=(sq_x, sq_y),
                        xytext=(sq_x + sq_coords_eitc[0], sq_y + sq_coords_eitc[1] + 0.003),
                        size=16)
        else:
            ax.annotate('Status quo, ' + sq_pct + '% audit rate', xy=(sq_x, sq_y),
                        xytext=(sq_x + sq_coords_pop[0], sq_y + sq_coords_pop[1]),
                        size=16)

    # Display y ticks x100 so decimal disparities read as percentage points.
    scale_y = 0.01
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / scale_y)))

    # Text offsets for the per-model status-quo audit-rate label:
    # (dx, dy when ref_pt_offset[i] is true, dy otherwise).
    if activity_code is not None:
        ref_dx, ref_dy_below, ref_dy_above = -200, -0.001, 0.00075
    elif eitc:
        ref_dx, ref_dy_below, ref_dy_above = 75, -0.0015, 0.00025
    else:
        ref_dx, ref_dy_below, ref_dy_above = 100, -0.001, 0.00075

    if annotatePct and len(modelnames) > 0:
        ref_budget = status_quo_point[sq_prefix + '_budget']
        ref_idx = _budget_index(ref_budget)
        ref_label = str(round(ref_budget * 100, 2)) + '%'

    for i, name in enumerate(modelnames):
        x = np.array(xs[i])
        y = np.array(ys[i])
        ax.plot(x, y, label=name, color=colors[i], linestyle=linestyles[i])
        if plot_error:
            # Builtin float replaces np.float (removed in NumPy 1.24); same dtype.
            half_width = np.array(errs[i]).astype(float) * 1.96
            ax.fill_between(x, y - half_width, y + half_width, alpha=0.2,
                            edgecolor='k', facecolor=colors[i])

        # Mark whole-percent audit rates: point j is labelled (j + 1) / 100 %,
        # so 1%, 2% and 3% sit at j = 99, 199, 299 (skipped if x is shorter).
        label_dy = (i % 2) * 0.001 if offset else 0.0
        for j in range(99, min(len(x), 300), 100):
            ax.annotate(str((j + 1) // 100) + '%', xy=(x[j] + 50, y[j] - label_dy), fontsize=16)
            ax.plot(x[j], y[j], 'ko', markersize=5)

        if annotatePct:
            ax.plot(x[ref_idx], y[ref_idx], 'ko', markersize=10)
            ref_dy = ref_dy_below if ref_pt_offset[i] else ref_dy_above
            ax.annotate(ref_label, xy=(x[ref_idx], y[ref_idx]),
                        xytext=(x[ref_idx] + ref_dx, y[ref_idx] + ref_dy), size=16)

        if annotateModelName:
            ax.annotate(name, (annotationcoords[i][0], annotationcoords[i][1]), size=20)

    fig.savefig(os.path.join(out, trajname + '.png'))
    return fig





###############################################
#### Trajectory Plots
###############################################

###############################################
#### Fig 8: Detected Underreporting and Disparity by Algorithm
###############################################


### Unrestricted, probabilistic disparity: oracle vs. regressor prediction,
### each targeting refundable credit and total underreporting.
fig8 = makeTrajectoryPlot(
    modelnames=['Refundable Credit Oracle',
                'Total Underreporting\n          Oracle',
                'Refundable Credit Prediction',
                'Total Underreporting\n        Prediction'],
    quantities=[data_dict['unres_eitc_oracle_dep_database_ref_cred'],
                data_dict['unres_eitc_oracle_dep_database'],
                data_dict['unres_eitc_reg_dep_database_ref_cred'],
                data_dict['unres_eitc_reg_dep_database']],
    status_quo_point=status_quo_point,
    # NOTE(review): 'light purple' and 'dark purple' are not in this file's
    # Stanford palette (cdict); confirm they are registered elsewhere (e.g. by
    # the trajectoryPlotters star import), otherwise matplotlib will likely
    # reject them as invalid colour names.
    colors=['link blue', 'dark blue', 'light purple', 'dark purple'],
    linestyles=['solid', 'solid', 'solid', 'solid'],
    trajname='regs_oracles_final',
    dict_x_varb='bootstrap_revenue',
    dict_y_varb='bootstrap_chen_fair',
    dict_se_varb='bootstrap_chen_fair',
    sq_y_varb='chen_fair',
    x_lab='Detected Underreporting ($ Millions)',
    y_lab='Disparity (probabilistic, percentage points)',
    plot_sq=True,
    activity_code=None,
    out=defaultout,
    eitc=True,
    offset=True,
    annotatePct=True,
    annotateModelName=True,
    ref_pt_offset=[True, True, True, True],
    sq_coords_eitc=[-2000, -0.002],
    sq_coords_pop=[-1000, 0.002],
    annotationcoords=[[8050, 0.023], [9000, 0.001], [4750, 0.035], [7200, -0.018]],
    plot_error=True)

# Release the figure returned by makeTrajectoryPlot.
plt.close(fig8)